{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "X2Fj4S3r0p1A"
      },
      "source": [
        "##### Copyright 2019 Google LLC.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "Okg-R95R1CaX"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "eeObMQsw1R_p"
      },
      "source": [
        "# Object pose alignment\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/6dof_alignment.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/6dof_alignment.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6S5QwOCcgwGI"
      },
      "source": [
        "Precisely estimating the pose of objects is fundamental to many industries. For instance, in augmented and virtual reality, it allows users to modify the state of some variable by interacting with these objects (e.g. volume controlled by a mug on the user's desk).\n",
        "\n",
        "This notebook illustrates how to use [Tensorflow Graphics](https://github.com/tensorflow/graphics) to estimate the rotation and translation of known 3D objects. \n",
        "![](https://storage.googleapis.com/tensorflow-graphics/notebooks/6dof_pose/task.jpg)\n",
        "\n",
        "\n",
        "\n",
        "This capability is illustrated by two different demos:\n",
        "1. **Machine learning** demo illustrating how to train a simple neural network capable of precisely estimating the rotation and translation of a given object with respect to a reference pose.\n",
        "2. **Mathematical optimization** demo that takes a different approach to the problem; does not use machine learning.\n",
        "\n",
        "**Note**: The easiest way to use this tutorial is as a Colab notebook, which allows you to dive in with no setup."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PNQ29y8Q4_cH"
      },
      "source": [
        "## Setup & Imports\n",
        "If Tensorflow Graphics is not installed on your system, the following cell can install the Tensorflow Graphics package for you."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "26AvKq8MJRGl"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow_graphics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UkPKOuyJKuKM"
      },
      "source": [
        "Now that Tensorflow Graphics is installed, let's import everything needed to run the demos contained in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "KlBviBxue7n0"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "\n",
        "from tensorflow_graphics.geometry.transformation import quaternion\n",
        "from tensorflow_graphics.math import vector\n",
        "from tensorflow_graphics.notebooks import threejs_visualization\n",
        "from tensorflow_graphics.notebooks.resources import tfg_simplified_logo\n",
        "\n",
        "tf.compat.v1.enable_v2_behavior()\n",
        "\n",
        "# Loads the Tensorflow Graphics simplified logo.\n",
        "vertices = tfg_simplified_logo.mesh['vertices'].astype(np.float32)\n",
        "faces = tfg_simplified_logo.mesh['faces']\n",
        "num_vertices = vertices.shape[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pD2ldsUvaP5V"
      },
      "source": [
        "## 1.  Machine Learning\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UmFx3BqjNWrN"
      },
      "source": [
        "### Model definition\n",
        "Given the 3D position of all the vertices of a known mesh, we would like a network that is capable of predicting the rotation parametrized by a quaternion (4 dimensional vector), and translation (3 dimensional vector) of this mesh with respect to a reference pose. Let's now create a very simple 3-layer fully connected network, and a loss for the task. Note that this model is very simple and definitely not optimal, which is out of scope for this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "5_EP7GzQb6sn"
      },
      "outputs": [],
      "source": [
        "# Constructs the model.\n",
        "model = keras.Sequential()\n",
        "model.add(layers.Flatten(input_shape=(num_vertices, 3)))\n",
        "model.add(layers.Dense(64, activation=tf.nn.tanh))\n",
        "model.add(layers.Dense(64, activation=tf.nn.relu))\n",
        "model.add(layers.Dense(7))\n",
        "\n",
        "\n",
        "def pose_estimation_loss(y_true, y_pred):\n",
        "  \"\"\"Pose estimation loss used for training.\n",
        "\n",
        "  This loss measures the average of squared distance between some vertices\n",
        "  of the mesh in 'rest pose' and the transformed mesh to which the predicted\n",
        "  inverse pose is applied. Comparing this loss with a regular L2 loss on the\n",
        "  quaternion and translation values is left as exercise to the interested\n",
        "  reader.\n",
        "\n",
        "  Args:\n",
        "    y_true: The ground-truth value.\n",
        "    y_pred: The prediction we want to evaluate the loss for.\n",
        "\n",
        "  Returns:\n",
        "    A scalar value containing the loss described in the description above.\n",
        "  \"\"\"\n",
        "  # y_true.shape : (batch, 7)\n",
        "  y_true_q, y_true_t = tf.split(y_true, (4, 3), axis=-1)\n",
        "  # y_pred.shape : (batch, 7)\n",
        "  y_pred_q, y_pred_t = tf.split(y_pred, (4, 3), axis=-1)\n",
        "\n",
        "  # vertices.shape: (num_vertices, 3)\n",
        "  # corners.shape:(num_vertices, 1, 3)\n",
        "  corners = tf.expand_dims(vertices, axis=1)\n",
        "\n",
        "  # transformed_corners.shape: (num_vertices, batch, 3)\n",
        "  # q and t shapes get pre-pre-padded with 1's following standard broadcast rules.\n",
        "  transformed_corners = quaternion.rotate(corners, y_pred_q) + y_pred_t\n",
        "\n",
        "  # recovered_corners.shape: (num_vertices, batch, 3)\n",
        "  recovered_corners = quaternion.rotate(transformed_corners - y_true_t,\n",
        "                                        quaternion.inverse(y_true_q))\n",
        "\n",
        "  # vertex_error.shape: (num_vertices, batch)\n",
        "  vertex_error = tf.reduce_sum((recovered_corners - corners)**2, axis=-1)\n",
        "\n",
        "  return tf.reduce_mean(vertex_error)\n",
        "\n",
        "\n",
        "optimizer = keras.optimizers.Adam()\n",
        "model.compile(loss=pose_estimation_loss, optimizer=optimizer)\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VPJeqKp3OaE0"
      },
      "source": [
        "### Data generation\n",
        "Now that we have a model defined, we need data to train it. For each sample in the training set, a random 3D rotation and 3D translation are sampled and applied to the vertices of our object. Each training sample consists of all the transformed vertices and the inverse rotation and translation that would allow to revert the rotation and translation applied to the sample."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "jtaNzJKBGi8s"
      },
      "outputs": [],
      "source": [
        "def generate_training_data(num_samples):\n",
        "  # random_angles.shape: (num_samples, 3)\n",
        "  random_angles = np.random.uniform(-np.pi, np.pi,\n",
        "                                    (num_samples, 3)).astype(np.float32)\n",
        "\n",
        "  # random_quaternion.shape: (num_samples, 4)\n",
        "  random_quaternion = quaternion.from_euler(random_angles)\n",
        "\n",
        "  # random_translation.shape: (num_samples, 3)\n",
        "  random_translation = np.random.uniform(-2.0, 2.0,\n",
        "                                         (num_samples, 3)).astype(np.float32)\n",
        "\n",
        "  # data.shape : (num_samples, num_vertices, 3)\n",
        "  data = quaternion.rotate(vertices[tf.newaxis, :, :],\n",
        "                           random_quaternion[:, tf.newaxis, :]\n",
        "                          ) + random_translation[:, tf.newaxis, :]\n",
        "\n",
        "  # target.shape : (num_samples, 4+3)\n",
        "  target = tf.concat((random_quaternion, random_translation), axis=-1)\n",
        "\n",
        "  return np.array(data), np.array(target)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "E3MIRZrQfPJa"
      },
      "outputs": [],
      "source": [
        "num_samples = 10000\n",
        "\n",
        "data, target = generate_training_data(num_samples)\n",
        "\n",
        "print(data.shape)   # (num_samples, num_vertices, 3): the vertices\n",
        "print(target.shape)  # (num_samples, 4+3): the quaternion and translation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "R5yebHu3OOgJ"
      },
      "source": [
        "### Training\n",
        "At this point, everything is in place to start training the neural network!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vquVVhxGC70_"
      },
      "outputs": [],
      "source": [
        "# Callback allowing to display the progression of the training task.\n",
        "class ProgressTracker(keras.callbacks.Callback):\n",
        "\n",
        "  def __init__(self, num_epochs, step=5):\n",
        "    self.num_epochs = num_epochs\n",
        "    self.current_epoch = 0.\n",
        "    self.step = step\n",
        "    self.last_percentage_report = 0\n",
        "\n",
        "  def on_epoch_end(self, batch, logs={}):\n",
        "    self.current_epoch += 1.\n",
        "    training_percentage = int(self.current_epoch * 100.0 / self.num_epochs)\n",
        "    if training_percentage - self.last_percentage_report >= self.step:\n",
        "      print('Training ' + str(\n",
        "          training_percentage) + '% complete. Training loss: ' + str(\n",
        "              logs.get('loss')) + ' | Validation loss: ' + str(\n",
        "                  logs.get('val_loss')))\n",
        "      self.last_percentage_report = training_percentage"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "TtntLH1NDAZV"
      },
      "outputs": [],
      "source": [
        "reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(\n",
        "    monitor='val_loss',\n",
        "    factor=0.5,\n",
        "    patience=10,\n",
        "    verbose=0,\n",
        "    mode='auto',\n",
        "    min_delta=0.0001,\n",
        "    cooldown=0,\n",
        "    min_lr=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "l0r2AuWsWpb-"
      },
      "outputs": [],
      "source": [
        "# google internal 1",
        "\n",
        "# Everything is now in place to train.\n",
        "EPOCHS = 100\n",
        "pt = ProgressTracker(EPOCHS)\n",
        "history = model.fit(\n",
        "    data,\n",
        "    target,\n",
        "    epochs=EPOCHS,\n",
        "    validation_split=0.2,\n",
        "    verbose=0,\n",
        "    batch_size=32,\n",
        "    callbacks=[reduce_lr_callback, pt])\n",
        "\n",
        "plt.plot(history.history['loss'])\n",
        "plt.plot(history.history['val_loss'])\n",
        "plt.ylim([0, 1])\n",
        "plt.legend(['loss', 'val loss'], loc='upper left')\n",
        "plt.xlabel('Train epoch')\n",
        "_ = plt.ylabel('Error [mean square distance]')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cqYThAKGjhRU"
      },
      "source": [
        "### Testing\n",
        "The network is now trained and ready to use!\n",
        "The displayed results consist of two images. The first image contains the object in 'rest pose' (pastel lemon color) and the rotated and translated object (pastel honeydew color). This effectively allows to observe how **different** the two configurations are. The second image also shows the object in rest pose, but this time the transformation predicted by our trained neural network is applied to the rotated and translated version. Hopefully, the two objects are now in a very **similar** pose.\n",
        "\n",
        "**Note**: press play multiple times to sample different test cases. You will notice that sometimes the scale of the object is off. This comes from the fact that quaternions can encode scale. Using a quaternion of unit norm would result in not changing the scale of the result. We let the interested reader experiment with adding this constraint either in the network architecture, or in the loss function. "
      ]
    },
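    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The unit-norm constraint mentioned in the note above can also be approximated as a post-processing step. The next cell is a minimal sketch rather than part of the original demo: `normalize_pose` is a hypothetical helper, not a TensorFlow Graphics function, that rescales the quaternion part of a 7-dimensional pose vector to unit length and leaves the translation untouched. Inside the network, the same operation could presumably be expressed with `tf.math.l2_normalize` applied to the first four outputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def normalize_pose(pose, eps=1e-12):\n",
        "  \"\"\"Returns a copy of `pose` whose quaternion part has unit norm.\n",
        "\n",
        "  Hypothetical helper used for illustration only.\n",
        "\n",
        "  Args:\n",
        "    pose: Array of shape (..., 7) holding a quaternion (4 values) followed by\n",
        "      a translation (3 values).\n",
        "    eps: Small constant that avoids a division by zero.\n",
        "\n",
        "  Returns:\n",
        "    An array of shape (..., 7) with a unit-norm quaternion part.\n",
        "  \"\"\"\n",
        "  pose = np.asarray(pose, dtype=np.float32)\n",
        "  q, t = pose[..., :4], pose[..., 4:]\n",
        "  norm = np.linalg.norm(q, axis=-1, keepdims=True)\n",
        "  return np.concatenate((q / np.maximum(norm, eps), t), axis=-1)\n",
        "\n",
        "\n",
        "# Made-up pose whose quaternion part has norm 2.\n",
        "example_pose = np.array([0.0, 0.0, 2.0, 0.0, 1.0, -1.0, 0.5])\n",
        "print(normalize_pose(example_pose))"
      ]
    },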
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bMUOcMsf5sia"
      },
      "source": [
        "Start with a helper function to apply a quaternion and a translation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "3CuzLFUR2TQc"
      },
      "outputs": [],
      "source": [
        "def transform_points(target_points, quaternion_variable, translation_variable):\n",
        "  return quaternion.rotate(target_points,\n",
        "                           quaternion_variable) + translation_variable"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "EJgJX6DO5zQO"
      },
      "source": [
        "Define a `threejs` viewer for the transformed shape: "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ijibJe1p2Kgk"
      },
      "outputs": [],
      "source": [
        "class Viewer(object):\n",
        "\n",
        "  def __init__(self, my_vertices):\n",
        "    my_vertices = np.asarray(my_vertices)\n",
        "    context = threejs_visualization.build_context()\n",
        "    light1 = context.THREE.PointLight.new_object(0x808080)\n",
        "    light1.position.set(10., 10., 10.)\n",
        "    light2 = context.THREE.AmbientLight.new_object(0x808080)\n",
        "    lights = (light1, light2)\n",
        "\n",
        "    material = context.THREE.MeshLambertMaterial.new_object({\n",
        "        'color': 0xfffacd,\n",
        "    })\n",
        "\n",
        "    material_deformed = context.THREE.MeshLambertMaterial.new_object({\n",
        "        'color': 0xf0fff0,\n",
        "    })\n",
        "\n",
        "    camera = threejs_visualization.build_perspective_camera(\n",
        "        field_of_view=30, position=(10.0, 10.0, 10.0))\n",
        "\n",
        "    mesh = {'vertices': vertices, 'faces': faces, 'material': material}\n",
        "    transformed_mesh = {\n",
        "        'vertices': my_vertices,\n",
        "        'faces': faces,\n",
        "        'material': material_deformed\n",
        "    }\n",
        "    geometries = threejs_visualization.triangular_mesh_renderer(\n",
        "        [mesh, transformed_mesh],\n",
        "        lights=lights,\n",
        "        camera=camera,\n",
        "        width=400,\n",
        "        height=400)\n",
        "\n",
        "    self.geometries = geometries\n",
        "\n",
        "  def update(self, transformed_points):\n",
        "    self.geometries[1].getAttribute('position').copyArray(\n",
        "        transformed_points.numpy().ravel().tolist())\n",
        "    self.geometries[1].getAttribute('position').needsUpdate = True"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vt8dw1VE58VH"
      },
      "source": [
        "Define a random rotation and translation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "dQBZ22DSzaMu"
      },
      "outputs": [],
      "source": [
        "def get_random_transform():\n",
        "  # Forms a random translation\n",
        "  with tf.name_scope('translation_variable'):\n",
        "    random_translation = tf.Variable(\n",
        "        np.random.uniform(-2.0, 2.0, (3,)), dtype=tf.float32)\n",
        "\n",
        "  # Forms a random quaternion\n",
        "  hi = np.pi\n",
        "  lo = -hi\n",
        "  random_angles = np.random.uniform(lo, hi, (3,)).astype(np.float32)\n",
        "  with tf.name_scope('rotation_variable'):\n",
        "    random_quaternion = tf.Variable(quaternion.from_euler(random_angles))\n",
        "\n",
        "  return random_quaternion, random_translation\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_9FfPxsj6S8J"
      },
      "source": [
        "Run the model to predict the transformation parameters, and visualize the result:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "N70Swvie6IWd"
      },
      "outputs": [],
      "source": [
        "random_quaternion, random_translation = get_random_transform()\n",
        "\n",
        "initial_orientation = transform_points(vertices, random_quaternion,\n",
        "                                       random_translation).numpy()\n",
        "viewer = Viewer(initial_orientation)\n",
        "\n",
        "predicted_transformation = model.predict(initial_orientation[tf.newaxis, :, :])\n",
        "\n",
        "predicted_inverse_q = quaternion.inverse(predicted_transformation[0, 0:4])\n",
        "predicted_inverse_t = -predicted_transformation[0, 4:]\n",
        "\n",
        "predicted_aligned = quaternion.rotate(initial_orientation + predicted_inverse_t,\n",
        "                                      predicted_inverse_q)\n",
        "\n",
        "viewer = Viewer(predicted_aligned)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "o6Aut-yJJApf"
      },
      "source": [
        "## 2. Mathematical optimization\n",
        "Here the problem is tackled using mathematical optimization, which is another traditional way to approach the problem of object pose estimation. Given correspondences between the object in 'rest pose' (pastel lemon color) and its rotated and translated counter part (pastel honeydew color), the problem can be formulated as a minimization problem. The loss function can for instance be defined as the sum of Euclidean distances between the corresponding points using the current estimate of the rotation and translation of the transformed object. One can then compute the derivative of the rotation and translation parameters with respect to this loss function, and follow the gradient direction until convergence. The following cell closely follows that procedure, and uses gradient descent to align the two objects. It is worth noting that although the results are good, there are more efficient ways to solve this specific problem. The interested reader is referred to the Kabsch algorithm for further details.\n",
        "\n",
        "**Note**: press play multiple times to sample different test cases. "
      ]
    },
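    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The Kabsch algorithm mentioned above solves the same alignment problem in closed form when correspondences are known. The next cell is a minimal NumPy sketch of it, not part of the original demo: the function name `kabsch` and the synthetic test points are illustrative assumptions. It centers both point sets, takes the SVD of their cross-covariance matrix, and corrects the sign so that the result is a proper rotation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def kabsch(source, target):\n",
        "  \"\"\"Finds the rigid transform that best maps `source` onto `target`.\n",
        "\n",
        "  Minimizes the sum of squared distances between rotation @ source_i +\n",
        "  translation and target_i.\n",
        "\n",
        "  Args:\n",
        "    source: Array of shape (N, 3).\n",
        "    target: Array of shape (N, 3), in correspondence with `source`.\n",
        "\n",
        "  Returns:\n",
        "    A rotation matrix of shape (3, 3) and a translation of shape (3,).\n",
        "  \"\"\"\n",
        "  source_mean = source.mean(axis=0)\n",
        "  target_mean = target.mean(axis=0)\n",
        "  # Cross-covariance of the centered point sets.\n",
        "  h = (source - source_mean).T @ (target - target_mean)\n",
        "  u, _, vt = np.linalg.svd(h)\n",
        "  # Flips the last axis if needed so that the result is a proper rotation.\n",
        "  d = np.sign(np.linalg.det(vt.T @ u.T))\n",
        "  rotation = vt.T @ np.diag([1.0, 1.0, d]) @ u.T\n",
        "  translation = target_mean - rotation @ source_mean\n",
        "  return rotation, translation\n",
        "\n",
        "\n",
        "# Synthetic test data: random points moved by a known rigid transform.\n",
        "rng = np.random.default_rng(0)\n",
        "points = rng.normal(size=(10, 3))\n",
        "q, _ = np.linalg.qr(rng.normal(size=(3, 3)))\n",
        "true_rotation = q * np.sign(np.linalg.det(q))  # Forces determinant +1.\n",
        "true_translation = np.array([0.5, -1.0, 2.0])\n",
        "moved = points @ true_rotation.T + true_translation\n",
        "\n",
        "rotation, translation = kabsch(points, moved)\n",
        "print(np.allclose(rotation, true_rotation),\n",
        "      np.allclose(translation, true_translation))"
      ]
    },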
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6vWEO0846b8h"
      },
      "source": [
        "Define the `loss` and `gradient` functions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "b8_TYq-31l3z"
      },
      "outputs": [],
      "source": [
        "def loss(target_points, quaternion_variable, translation_variable):\n",
        "  transformed_points = transform_points(target_points, quaternion_variable,\n",
        "                                        translation_variable)\n",
        "  error = (vertices - transformed_points) / num_vertices\n",
        "  return vector.dot(error, error)\n",
        "\n",
        "\n",
        "def gradient_loss(target_points, quaternion, translation):\n",
        "  with tf.GradientTape() as tape:\n",
        "    loss_value = loss(target_points, quaternion, translation)\n",
        "  return tape.gradient(loss_value, [quaternion, translation])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9y8DLxc26h0g"
      },
      "source": [
        "Create the optimizer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1YzO2Dpy4oDa"
      },
      "outputs": [],
      "source": [
        "learning_rate = 0.05\n",
        "with tf.name_scope('optimization'):\n",
        "  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "uPdMbQvO6prf"
      },
      "source": [
        "Initialize the random transformation, run the optimization and animate the result."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vw9RMRPVz0YL"
      },
      "outputs": [],
      "source": [
        "random_quaternion, random_translation = get_random_transform()\n",
        "\n",
        "transformed_points = transform_points(vertices, random_quaternion,\n",
        "                                      random_translation)\n",
        "\n",
        "viewer = Viewer(transformed_points)\n",
        "\n",
        "nb_iterations = 100\n",
        "for it in range(nb_iterations):\n",
        "  gradients_loss = gradient_loss(vertices, random_quaternion,\n",
        "                                 random_translation)\n",
        "  optimizer.apply_gradients(\n",
        "      zip(gradients_loss, (random_quaternion, random_translation)))\n",
        "  transformed_points = transform_points(vertices, random_quaternion,\n",
        "                                        random_translation)\n",
        "\n",
        "  viewer.update(transformed_points)\n",
        "  time.sleep(0.1)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "6dof alignment.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}